/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.api.scala.migration

import java.util

import org.apache.flink.api.common.accumulators.IntCounter
import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.api.common.state._
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.tuple.Tuple2
import org.apache.flink.api.scala.createTypeInformation
import org.apache.flink.api.scala.migration.CustomEnum.CustomEnum
import org.apache.flink.configuration.Configuration
import org.apache.flink.contrib.streaming.state.RocksDBStateBackend
import org.apache.flink.runtime.state.memory.MemoryStateBackend
import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext, StateBackendLoader}
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
import org.apache.flink.streaming.api.functions.source.SourceFunction
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.test.checkpointing.utils.SavepointMigrationTestBase
import org.apache.flink.testutils.migration.MigrationVersion
import org.apache.flink.util.Collector
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.{Assert, Ignore, Test}

import scala.util.{Failure, Try}

object StatefulJobWBroadcastStateMigrationITCase {

  /** Flink versions whose checked-in savepoints are restored, oldest first. */
  private val savepointVersions: Seq[MigrationVersion] = Seq(
    MigrationVersion.v1_5,
    MigrationVersion.v1_6,
    MigrationVersion.v1_7,
    MigrationVersion.v1_8,
    MigrationVersion.v1_9,
    MigrationVersion.v1_10,
    MigrationVersion.v1_11)

  /** State backends exercised for every version, in the order they run per version. */
  private val savepointBackends: Seq[String] = Seq(
    StateBackendLoader.MEMORY_STATE_BACKEND_NAME,
    StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME)

  /**
    * Every (savepoint version, backend name) combination, grouped by version. JUnit creates
    * one test-class instance per entry.
    */
  @Parameterized.Parameters(name = "Migrate Savepoint / Backend: {0}")
  def parameters: util.Collection[(MigrationVersion, String)] = {
    val combinations =
      for {
        version <- savepointVersions
        backend <- savepointBackends
      } yield (version, backend)
    util.Arrays.asList(combinations: _*)
  }

  // TODO To generate savepoints for a specific Flink version / backend type, change these
  // TODO values accordingly, e.g. to generate for 1.3 with RocksDB set them to
  // TODO (MigrationVersion.v1_3, StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME).
  // TODO Note: generate the savepoint from the release branch, not from master.
  val GENERATE_SAVEPOINT_VER: MigrationVersion = MigrationVersion.v1_9
  val GENERATE_SAVEPOINT_BACKEND_TYPE: String = StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME

  /**
   * Expected value of the element-counting accumulator. It is passed, together with the
   * accumulator name, to `executeAndSavepoint` / `restoreAndExecute`.
   */
  val NUM_ELEMENTS = 4
}

/**
  * ITCase for migrating Scala state types, together with broadcast state, across different
  * Flink versions.
  *
  * @param migrationVersionAndBackend the Flink version whose savepoint is restored, and the
  *                                   name of the state backend that savepoint belongs to
  */
@RunWith(classOf[Parameterized])
class StatefulJobWBroadcastStateMigrationITCase(
                                        migrationVersionAndBackend: (MigrationVersion, String))
  extends SavepointMigrationTestBase with Serializable {

  /**
    * Generates the savepoint for the version / backend configured in the companion object.
    * It is ignored by default and is meant to be enabled manually when a new savepoint is needed.
    */
  @Test
  @Ignore
  def testCreateSavepointWithBroadcastState(): Unit = {
    val env = createEnvironment(
      StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_BACKEND_TYPE)

    buildJob(env, new TestBroadcastProcessFunction)

    executeAndSavepoint(
      env,
      s"src/test/resources/stateful-scala-with-broadcast" +
        s"-udf-migration-itcase-flink" +
        s"${StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_VER}" +
        s"-${StatefulJobWBroadcastStateMigrationITCase.GENERATE_SAVEPOINT_BACKEND_TYPE}-savepoint",
      new Tuple2(
        AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR,
        StatefulJobWBroadcastStateMigrationITCase.NUM_ELEMENTS
      )
    )
  }

  /**
    * Restores the savepoint for this parameterization. It then checks that both broadcast
    * states match the expected maps for every keyed element processed.
    */
  @Test
  def testRestoreSavepointWithBroadcast(): Unit = {
    val env = createEnvironment(migrationVersionAndBackend._2)

    val expectedFirstState: Map[Long, Long] =
      Map(0L -> 0L, 1L -> 1L, 2L -> 2L, 3L -> 3L)
    val expectedSecondState: Map[String, String] =
      Map("0" -> "0", "1" -> "1", "2" -> "2", "3" -> "3")

    buildJob(env, new VerifyingBroadcastProcessFunction(expectedFirstState, expectedSecondState))

    restoreAndExecute(
      env,
      SavepointMigrationTestBase.getResourceFilename(
        s"stateful-scala-with-broadcast" +
          s"-udf-migration-itcase-flink${migrationVersionAndBackend._1}" +
          s"-${migrationVersionAndBackend._2}-savepoint"),
      new Tuple2(
        AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR,
        StatefulJobWBroadcastStateMigrationITCase.NUM_ELEMENTS)
    )
  }

  /**
    * Creates the environment shared by both tests. It uses event time, a checkpoint interval
    * of 500 and parallelism / max parallelism 4.
    *
    * @param backendType `StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME` or
    *                    `StateBackendLoader.MEMORY_STATE_BACKEND_NAME`
    * @throws UnsupportedOperationException for any other backend name
    */
  private def createEnvironment(backendType: String): StreamExecutionEnvironment = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

    backendType match {
      case StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME =>
        env.setStateBackend(new RocksDBStateBackend(new MemoryStateBackend()))
      case StateBackendLoader.MEMORY_STATE_BACKEND_NAME =>
        env.setStateBackend(new MemoryStateBackend())
      case _ => throw new UnsupportedOperationException
    }

    // NOTE(review): this replaces the backend selected above, so the RocksDB
    // parameterizations also run on MemoryStateBackend. The savepoint-generating test goes
    // through the same code, so the checked-in savepoints were presumably taken this way
    // too. Confirm (and possibly regenerate the savepoints) before removing this line.
    env.setStateBackend(new MemoryStateBackend)
    env.enableCheckpointing(500)
    env.setParallelism(4)
    env.setMaxParallelism(4)
    env
  }

  /**
    * Wires the job used by both tests into `env`. The job consists of:
    *  - a keyed stream passing through [[StatefulFlatMapper]];
    *  - a connection to a broadcast stream carrying both broadcast state descriptors;
    *  - processing by `processFunction`, with results counted by [[AccumulatorCountingSink]].
    *
    * The construction order must stay as it is. The flat-map and process operators have no
    * explicit uid, so their generated operator ids presumably depend on the job topology.
    */
  private def buildJob(
      env: StreamExecutionEnvironment,
      processFunction: KeyedBroadcastProcessFunction
        [Long, (Long, Long), (Long, Long), (Long, Long)]): Unit = {
    val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
      "broadcast-state-1",
      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
      BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])

    val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
      "broadcast-state-2",
      BasicTypeInfo.STRING_TYPE_INFO,
      BasicTypeInfo.STRING_TYPE_INFO)

    val stream = env
      .addSource(
        new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedSource")
      .keyBy(newFirstFieldKeySelector())
      .flatMap(new StatefulFlatMapper)
      .keyBy(newFirstFieldKeySelector())

    val broadcastStream = env
      .addSource(
        new CheckpointedSource(4)).setMaxParallelism(1).uid("checkpointedBroadcastSource")
      .broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc)

    stream
      .connect(broadcastStream)
      .process(processFunction)
      .addSink(new AccumulatorCountingSink)
  }

  /** Creates a key selector that keys a `(Long, Long)` element by its first field. */
  private def newFirstFieldKeySelector(): KeySelector[(Long, Long), Long] =
    new KeySelector[(Long, Long), Long] {
      override def getKey(value: (Long, Long)): Long = value._1
    }
}

/**
  * Broadcast process function used when generating savepoints.
  *
  * Keyed elements are forwarded unchanged. Every broadcast element `(k, v)` is stored in two
  * broadcast states: as `k -> v` in "broadcast-state-1" and as `k.toString -> v.toString` in
  * "broadcast-state-2".
  */
class TestBroadcastProcessFunction
  extends KeyedBroadcastProcessFunction[Long, (Long, Long), (Long, Long), (Long, Long)] {

  lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
    "broadcast-state-1",
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])

  val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
    "broadcast-state-2",
    BasicTypeInfo.STRING_TYPE_INFO,
    BasicTypeInfo.STRING_TYPE_INFO)

  /** Passes every keyed element through untouched. */
  @throws[Exception]
  override def processElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction
        [Long, (Long, Long), (Long, Long), (Long, Long)]#ReadOnlyContext,
      out: Collector[(Long, Long)]): Unit =
    out.collect(value)

  /** Records the broadcast element in both broadcast states; nothing is emitted. */
  @throws[Exception]
  override def processBroadcastElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction
        [Long, (Long, Long), (Long, Long), (Long, Long)]#Context,
      out: Collector[(Long, Long)]): Unit = {
    val (key, payload) = value
    ctx.getBroadcastState(firstBroadcastStateDesc).put(key, payload)
    ctx.getBroadcastState(secondBroadcastStateDesc).put(key.toString, payload.toString)
  }
}

/**
  * Companion of the test source. `CHECKPOINTED_STRING` holds the same text that
  * `CheckpointedSource` writes into its operator state on every snapshot.
  */
@SerialVersionUID(1L)
private object CheckpointedSource {
  var CHECKPOINTED_STRING: String = "Here be dragons!"
}

/**
  * Test source with the following behavior:
  *  - emits a watermark of 0, then the elements `(i, i)` for `i` in `[0, numElements)`;
  *  - afterwards idles until cancelled;
  *  - on every snapshot, replaces its operator list state with a single marker
  *    [[CustomCaseClass]].
  *
  * @param numElements number of elements to emit
  */
@SerialVersionUID(1L)
private class CheckpointedSource(val numElements: Int)
  extends SourceFunction[(Long, Long)] with CheckpointedFunction {

  // Written by cancel() and polled by the idle loop in run(). These are presumably invoked
  // from different threads, so the flag must be volatile for the update to be visible.
  @volatile private var isRunning = true
  private var state: ListState[CustomCaseClass] = _

  /** Emits all elements atomically w.r.t. checkpoints, then idles until [[cancel]]. */
  @throws[Exception]
  override def run(ctx: SourceFunction.SourceContext[(Long, Long)]): Unit = {
    ctx.emitWatermark(new Watermark(0))
    ctx.getCheckpointLock synchronized {
      for (i <- 0 until numElements) {
        ctx.collect((i.toLong, i.toLong))
      }
    }
    // don't emit a final watermark so that we don't trigger the registered event-time
    // timers
    while (isRunning) Thread.sleep(20)
  }

  /** Makes [[run]] leave its idle loop. */
  override def cancel(): Unit = {
    isRunning = false
  }

  /** Registers the operator list state "sourceState". */
  override def initializeState(context: FunctionInitializationContext): Unit = {
    state = context.getOperatorStateStore.getListState(
      new ListStateDescriptor[CustomCaseClass](
        "sourceState", createTypeInformation[CustomCaseClass]))
  }

  /** Replaces the list state content with the single marker value. */
  override def snapshotState(context: FunctionSnapshotContext): Unit = {
    state.clear()
    state.add(CustomCaseClass(CheckpointedSource.CHECKPOINTED_STRING, 123))
  }
}

/** Holds the name of the accumulator that `AccumulatorCountingSink` increments. */
@SerialVersionUID(1L)
private object AccumulatorCountingSink {
  var NUM_ELEMENTS_ACCUMULATOR: String = s"${classOf[AccumulatorCountingSink[_]]}_NUM_ELEMENTS"
}

/**
  * Sink that counts every element it receives. The count goes into the accumulator named
  * `AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR`.
  *
  * @tparam T element type; elements themselves are ignored
  */
@SerialVersionUID(1L)
private class AccumulatorCountingSink[T] extends RichSinkFunction[T] {

  /** Registers a fresh [[IntCounter]] under the shared accumulator name. */
  @throws[Exception]
  override def open(parameters: Configuration): Unit = {
    super.open(parameters)
    getRuntimeContext.addAccumulator(
      AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR, new IntCounter)
  }

  /** Adds one to the accumulator for every received element. */
  @throws[Exception]
  override def invoke(value: T): Unit = {
    getRuntimeContext.getAccumulator(
      AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR).add(1)
  }
}

/**
  * Keyed flat-mapper that forwards its input unchanged. For every element it also writes one
  * value into each of several keyed value states, whose types cover:
  *  - a case class and a nested case class;
  *  - `List`;
  *  - `Try` (success and failure);
  *  - `Option` (`Some` and `None`);
  *  - `Either` (`Left` and `Right`);
  *  - `CustomEnum`.
  */
class StatefulFlatMapper extends RichFlatMapFunction[(Long, Long), (Long, Long)] {

  private var caseClassState: ValueState[CustomCaseClass] = _
  private var caseClassWithNestingState: ValueState[CustomCaseClassWithNesting] = _
  private var collectionState: ValueState[List[CustomCaseClass]] = _
  private var tryState: ValueState[Try[CustomCaseClass]] = _
  private var tryFailureState: ValueState[Try[CustomCaseClass]] = _
  private var optionState: ValueState[Option[CustomCaseClass]] = _
  private var optionNoneState: ValueState[Option[CustomCaseClass]] = _
  private var eitherLeftState: ValueState[Either[CustomCaseClass, String]] = _
  private var eitherRightState: ValueState[Either[CustomCaseClass, String]] = _
  private var enumOneState: ValueState[CustomEnum] = _
  private var enumThreeState: ValueState[CustomEnum] = _

  /** Registers every value state with type information derived by the Scala API. */
  override def open(parameters: Configuration): Unit = {
    caseClassState = valueState(
      "caseClassState", createTypeInformation[CustomCaseClass])
    caseClassWithNestingState = valueState(
      "caseClassWithNestingState", createTypeInformation[CustomCaseClassWithNesting])
    collectionState = valueState(
      "collectionState", createTypeInformation[List[CustomCaseClass]])
    tryState = valueState(
      "tryState", createTypeInformation[Try[CustomCaseClass]])
    tryFailureState = valueState(
      "tryFailureState", createTypeInformation[Try[CustomCaseClass]])
    optionState = valueState(
      "optionState", createTypeInformation[Option[CustomCaseClass]])
    optionNoneState = valueState(
      "optionNoneState", createTypeInformation[Option[CustomCaseClass]])
    eitherLeftState = valueState(
      "eitherLeftState", createTypeInformation[Either[CustomCaseClass, String]])
    eitherRightState = valueState(
      "eitherRightState", createTypeInformation[Either[CustomCaseClass, String]])
    enumOneState = valueState(
      "enumOneState", createTypeInformation[CustomEnum])
    enumThreeState = valueState(
      "enumThreeState", createTypeInformation[CustomEnum])
  }

  /** Writes a value derived from `in` into every state, then emits `in` unchanged. */
  override def flatMap(in: (Long, Long), collector: Collector[(Long, Long)]): Unit = {
    caseClassState.update(CustomCaseClass(in._1.toString, in._2 * 2))
    caseClassWithNestingState.update(
      CustomCaseClassWithNesting(in._1, CustomCaseClass(in._1.toString, in._2 * 2)))
    collectionState.update(List(CustomCaseClass(in._1.toString, in._2 * 2)))
    tryState.update(Try(CustomCaseClass(in._1.toString, in._2 * 5)))
    tryFailureState.update(Failure(new RuntimeException))
    optionState.update(Some(CustomCaseClass(in._1.toString, in._2 * 2)))
    optionNoneState.update(None)
    eitherLeftState.update(Left(CustomCaseClass(in._1.toString, in._2 * 2)))
    eitherRightState.update(Right((in._1 * 3).toString))
    enumOneState.update(CustomEnum.ONE)
    // Each enum state holds its own value. THREE belongs in enumThreeState; writing it to
    // enumOneState would overwrite ONE and leave enumThreeState empty.
    enumThreeState.update(CustomEnum.THREE)

    collector.collect(in)
  }

  /** Registers a keyed value state named `name` using the given type information. */
  private def valueState[T](name: String, typeInfo: TypeInformation[T]): ValueState[T] =
    getRuntimeContext.getState(new ValueStateDescriptor[T](name, typeInfo))
}

/**
  * Broadcast process function used when restoring savepoints.
  *
  * For every keyed element it asserts that both broadcast states hold exactly the expected
  * entries, then forwards the element. Broadcast elements are ignored, so this function never
  * modifies the restored states.
  *
  * @param firstExpectedBroadcastState  expected content of "broadcast-state-1"
  * @param secondExpectedBroadcastState expected content of "broadcast-state-2"
  */
class VerifyingBroadcastProcessFunction(
    firstExpectedBroadcastState: Map[Long, Long],
    secondExpectedBroadcastState: Map[String, String])
  extends KeyedBroadcastProcessFunction[Long, (Long, Long), (Long, Long), (Long, Long)] {

  lazy val firstBroadcastStateDesc = new MapStateDescriptor[Long, Long](
    "broadcast-state-1",
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]],
    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]])

  val secondBroadcastStateDesc = new MapStateDescriptor[String, String](
    "broadcast-state-2",
    BasicTypeInfo.STRING_TYPE_INFO,
    BasicTypeInfo.STRING_TYPE_INFO)

  /**
    * Checks both broadcast states against the expected maps, then emits `value`.
    *
    * @throws AssertionError if either state differs from its expected content
    */
  @throws[Exception]
  override def processElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction
        [Long, (Long, Long), (Long, Long), (Long, Long)]#ReadOnlyContext,
      out: Collector[(Long, Long)]): Unit = {
    // Explicit converters instead of the deprecated implicit JavaConversions.
    import scala.collection.JavaConverters._

    // Whole-map comparison reports missing, unexpected and mismatched entries in a single
    // assertion. Looking keys up with Option.get would fail on the first unexpected key
    // with an opaque NoSuchElementException.
    val actualFirstState: Map[Long, Long] =
      ctx.getBroadcastState(firstBroadcastStateDesc).immutableEntries().asScala
        .map(entry => entry.getKey -> entry.getValue)
        .toMap
    Assert.assertEquals(
      s"Unexpected content of ${firstBroadcastStateDesc.getName}",
      firstExpectedBroadcastState,
      actualFirstState)

    val actualSecondState: Map[String, String] =
      ctx.getBroadcastState(secondBroadcastStateDesc).immutableEntries().asScala
        .map(entry => entry.getKey -> entry.getValue)
        .toMap
    Assert.assertEquals(
      s"Unexpected content of ${secondBroadcastStateDesc.getName}",
      secondExpectedBroadcastState,
      actualSecondState)

    out.collect(value)
  }

  /** Intentionally a no-op: the restored broadcast state must stay untouched. */
  @throws[Exception]
  override def processBroadcastElement(
      value: (Long, Long),
      ctx: KeyedBroadcastProcessFunction
        [Long, (Long, Long), (Long, Long), (Long, Long)]#Context,
      out: Collector[(Long, Long)]): Unit = {

    // do nothing
  }
}
